Interpretable Insights in Healthcare: Explaining Model Decisions in Image and Tabular Data Analysis with SHAP¶
1. Background and Relevance¶
Over the past few years, deep learning and machine learning models have generated considerable interest in critical areas such as healthcare, finance, and criminal justice. However, the lack of explainability of such models is widely considered a major barrier to adopting automated AI-based decision-making processes (Chander et al., 2018).
As a result, there is a growing demand for transparency and accountability in the decision-making processes of such systems in these critical domains. Particularly in the medical field, explainable AI plays a pivotal role in helping medical practitioners and healthcare professionals make decisions that are explainable, transparent, understandable, and accountable (Holzinger et al., 2017). Considering that such decisions significantly impact the further treatment process and, in some cases, the patient's well-being, explainability becomes even more critical when severe misclassifications occur.
Image Data¶
Doctors often manually examine thousands of images from a single medical procedure in conventional imaging diagnostics. This process is incredibly time-consuming and relies on the doctor’s prolonged attention. However, by employing deep learning and techniques for AI-based extraction of information from images, this process can be significantly streamlined to become much faster, more convenient, and more resource-efficient. After all, the transparent and explainable decisions or results generated by such a model would only need to be validated by a doctor (Meske & Bunde, 2020; Knapič et al., 2021).
Tabular Data¶
Despite their accuracy and speed, AI systems often lack transparency in their decision-making processes, posing risks, especially in healthcare. Therefore, building explainable AI is crucial. Since heart disease is one of the leading causes of death globally, developing AI models that can predict heart disease while providing clear explanations for their decisions is vital (Dave et al., 2020).
Recent legal trends also emphasize transparency requirements for information systems to mitigate unintended adverse effects in decision-making processes. In particular, the European Union's General Data Protection Regulation (GDPR) grants users the right to be informed about machine-generated decisions that affect them (Voigt & Von Dem Bussche, 2017). Consequently, patients affected by decisions made by a machine learning model may seek to understand the reasons for the decisions and outcomes that such a system has arrived at (Knapič et al., 2021b).
1.2 This Project¶
To address this issue, we will utilize SHAP (SHapley Additive exPlanations) to explain complex model decisions and answer the following research questions:
RQ1: What are the key differences in the interpretability of machine learning models between image classification and tabular data analysis when utilizing SHAP explanations?
RQ2: How does SHAP reveal the importance of different regions or structures within MRI brain scans in the decision-making process of machine learning models for brain tumor detection?
RQ3: How does SHAP analysis reveal the importance of various clinical and demographic features in the prediction of heart failure using a machine learning model?
1.2.1 Data¶
MRI Brain Scans: The dataset contains MRI brain scans; the goal is to classify each scan into one of four classes: three tumor types or no tumor (Brain Tumor MRI Dataset, 2021).
Heart Failure: In this dataset, 5 heart datasets are combined over 11 common features, making it the largest heart disease dataset available so far for research purposes (Heart Failure Prediction Dataset, 2021).
1.2.2 Methods¶
Transfer Learning: Transfer learning is a machine learning method where a model developed for one task is reused as the starting point for a model on a second task, offering advantages such as reduced computational costs, smaller necessary dataset sizes, and enhanced model generalizability (Murel Jacob & Kavlakoglu, 2024).
Several types of transfer learning have been proposed for medical imaging data and have been very effective, such as AlexNet, SPP-Net, VGGNet, ResNet, GoogLeNet, etc., with these models, including ResNet, demonstrating promising outcomes in the analysis of medical images (Salehi et al., 2023). For our project, we will utilize the ResNet50 architecture from TensorFlow, adapt it, and retrain it on MRI Brain scans.
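As a rough sketch of this setup (illustrative only: `weights=None` keeps the example offline and simply reuses the architecture, whereas `weights="imagenet"` would load the pretrained ImageNet weights; the layer sizes mirror the model trained later in this notebook):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras.models import Sequential

# Reuse the ResNet50 architecture as a feature extractor (top layer removed).
# weights="imagenet" would load pretrained weights; None keeps this sketch offline.
base = tf.keras.applications.resnet50.ResNet50(
    include_top=False, weights=None, input_shape=(100, 100, 3))
base.trainable = True  # set to False to freeze the base when using pretrained weights

model = Sequential([
    base,
    Flatten(),
    Dense(122, activation="relu"),
    Dense(4, activation="softmax"),  # 4 classes: glioma, meningioma, pituitary, notumor
])
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])
```

Freezing the base (`base.trainable = False`) is the classic transfer-learning pattern when pretrained weights are used; with `weights=None`, as later in this notebook, the full network is trained from scratch on the MRI scans.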
Tabular Data - Neural Network:
For the tabular data, a simple artificial neural net will be built and trained using TensorFlow for binary classification.
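A minimal sketch of such a network (assumptions: 11 numeric input features, matching the heart-failure dataset; the hidden-layer sizes are illustrative, not the final architecture):

```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Sequential

# Simple feed-forward net for binary classification (heart disease: yes/no).
model = Sequential([
    Input(shape=(11,)),              # 11 clinical/demographic features (assumed)
    Dense(32, activation="relu"),
    Dense(16, activation="relu"),
    Dense(1, activation="sigmoid"),  # predicted probability of heart disease
])
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
```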
SHAP:
SHAP is an explainable AI tool that provides a unified framework for interpreting the output of any machine learning model. It quantifies the contribution of each feature to a model's prediction, allowing for better understanding and transparency (SHAP, n.d.). We will utilize SHAP to provide explanations for both datasets.
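The idea behind SHAP can be illustrated without the library itself: a feature's Shapley value is its average marginal contribution over all feature coalitions, with "absent" features replaced by background values. A minimal NumPy sketch on a toy 3-feature linear model (all names and numbers here are illustrative; SHAP approximates this computation efficiently for real models):

```python
import math
from itertools import combinations

import numpy as np

def shapley_values(f, x, background):
    """Exact Shapley values for one instance x of an arbitrary function f."""
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(len(others) + 1):
            for S in combinations(others, k):
                # Shapley weight: |S|! * (n - |S| - 1)! / n!
                weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
                with_i = background.copy()
                with_i[list(S) + [i]] = x[list(S) + [i]]     # coalition S plus feature i
                without_i = background.copy()
                without_i[list(S)] = x[list(S)]              # coalition S alone
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi

# Toy linear model f(z) = w.z + b: the exact SHAP values are w_i * (x_i - mu_i)
w, b = np.array([2.0, -1.0, 0.5]), 3.0
f = lambda z: float(w @ z + b)
x, mu = np.array([1.0, 2.0, 3.0]), np.zeros(3)
phi = shapley_values(f, x, mu)
print(phi)  # equals w * (x - mu), i.e. [2, -2, 1.5]
print(phi.sum(), f(x) - f(mu))  # SHAP values sum to f(x) - E[f]
```

The additivity property in the last line (contributions sum to the difference between the prediction and the background expectation) is what makes SHAP explanations directly comparable across features and instances.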
2. MRI Brain Scans¶
In this part of our analysis, we will classify and interpret different types of brain tumors. This is divided into 3 main parts:
Multi-Class Classification: We first train a model to classify MRI scans into four categories: glioma, meningioma, pituitary, and no tumor. We then utilize SHAP explanations to identify which regions of the brain are most relevant for distinguishing between these classes.
Binary Classification: We train a separate model to distinguish between meningioma and no tumor. This model helps us understand how the classification task changes when reduced to a binary problem and allows us to compare the SHAP explanations with the multi-class model.
Three-Class Classification: Lastly, we train a model to classify only the tumor types (glioma, meningioma, and pituitary tumor), excluding the no-tumor class. This approach helps us investigate if the SHAP explanations differ when focusing solely on different tumor types.
2.1 Multi-Class Classification¶
Prerequisites: import all of the relevant libraries.
import os
import random

import cv2 as cv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import shap
import tensorflow as tf
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPool2D, Dropout
from tensorflow.keras.models import Sequential
Checking the current working directory and setting the paths to the folders Training and Testing where the MRI brain scans are stored.
cwd = os.getcwd()
print("Current working directory:", cwd)
Current working directory: C:\Users\marku\Documents\UNI\Semester_4\AI_2
path_dir_tr = "C:insert_your_path\\tumor\\Training"
path_dir_te = "C:insert_your_path\\tumor\\Testing"
2.1.1 Loading and Preprocessing the MRI Brain Scans¶
The first step involves building the dataset for classifying different types of brain tumors from MRI scans. The 4 categories are: glioma, meningioma, notumor, and pituitary.
The dataset comprises MRI images stored in corresponding folders for each category. In this notebook, we will read these images, preprocess them, and label them based on their folder names. Additionally, we will transform the images into a format suitable for training with the ResNet50 architecture.
The training and test data have already been separated into folders. Therefore, we will load the train and test data separately from the respective directories.
# Define the categories for classification
Categories = ["glioma","meningioma","notumor","pituitary"]
2.1.2 Data Pre-processing¶
2.1.2.1 Train Data¶
As a first step, the training data will be loaded from the specified directory.
data = []
# This function reads MRI brain scan images from the training directories, resizes them,
# and stores them together with their class labels in the 'data' list.
def create_data():
    for category in Categories:
        path = os.path.join(path_dir_tr, category)
        for img in os.listdir(path):
            try:
                img_array = cv.imread(os.path.join(path, img))
                # Resize the images to 100x100 pixels
                new_array = cv.resize(img_array, (100, 100))
                data.append([new_array, category])
            except Exception:
                # Skip files that cannot be read as images
                pass
create_data()
For more efficient training, we will reshuffle the training data and set a seed for better reproducibility.
seed_value = 777
np.random.seed(seed_value)
np.random.shuffle(data)
# Initialize empty lists to store features (images) and labels
x_data = []
y_data = []
# Iterate through the data and separate features and labels
for features, labels in data:
    x_data.append(features)
    y_data.append(labels)
y_data[1:10]
['notumor', 'notumor', 'notumor', 'meningioma', 'pituitary', 'glioma', 'pituitary', 'notumor', 'glioma']
df = pd.DataFrame(y_data,columns=["labels"])
df.head()
|   | labels |
|---|---|
| 0 | notumor |
| 1 | notumor |
| 2 | notumor |
| 3 | notumor |
| 4 | meningioma |
Before training our model, it's important to understand the distribution of our dataset. We'll visualize how many samples we have for each category of brain tumors using a count plot.
sns.countplot(x='labels', data=df)
<Axes: xlabel='labels', ylabel='count'>
From the count plot, we can see that the dataset is quite balanced across the different categories. Next, we will transform the data into a NumPy array and encode the labels to prepare for model training.
x_data = np.array(x_data)
## mapping to integers
label_mapping = {'glioma': 0, 'meningioma': 1, 'pituitary': 2, 'notumor': 3}
## applying mapping on the targets
y_data = np.array([label_mapping[label] for label in y_data])
y_data
array([3, 3, 3, ..., 1, 3, 3])
y_data[421]
1
Looking at an MRI scan which shows a meningioma tumor.
plt.imshow(x_data[421])
<matplotlib.image.AxesImage at 0x1ea489437d0>
2.1.2.2 Test Data¶
The same data loading and pre-processing steps will be done for the test data.
test_data = []
# This function reads MRI brain scan images from the test directories, resizes them,
# and stores them together with their class labels in the 'test_data' list.
def create_test_data():
    for category in Categories:
        path = os.path.join(path_dir_te, category)
        for img in os.listdir(path):
            try:
                img_array = cv.imread(os.path.join(path, img))
                # Resize the images to 100x100 pixels
                new_array = cv.resize(img_array, (100, 100))
                test_data.append([new_array, category])
            except Exception:
                # Skip files that cannot be read as images
                pass
create_test_data()
test_data[0]
[array([[[0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0]],
        ...,
        [[0, 0, 0],
         [0, 0, 0],
         ...,
         [0, 0, 0]]], dtype=uint8),
 'glioma']
random.shuffle(test_data)
x_test = []
y_test = []
for features, labels in test_data:
    x_test.append(features)
    y_test.append(labels)
x_test = np.array(x_test)
y_test[38]
'pituitary'
plt.imshow(x_test[38])
<matplotlib.image.AxesImage at 0x1ea48980910>
## applying the same mapping for the test set
y_test = np.array([label_mapping[label] for label in y_test])
2.1.3 Transfer Learning / Model Training¶
In this section, we will utilize the ResNet50 architecture, a powerful convolutional neural network pre-trained on the ImageNet dataset (more than 14 million images). Transfer learning allows us to leverage this architecture (and optionally its pretrained weights) and retrain it on the MRI brain scans.
Before starting, we check that the train and test sets have the required shape.
x_data.shape
(5712, 100, 100, 3)
x_test.shape
(1311, 100, 100, 3)
seed_value = 777
tf.random.set_seed(seed_value)
# Define the model architecture using Sequential
model = Sequential()
# Add ResNet50 as a base model (excluding the top layer)
resnet = tf.keras.applications.resnet50.ResNet50(include_top=False,
weights=None, # weights will be retrained (weights = "imagenet" would use the pretrained weights)
input_shape=(100,100,3),
classes=4)
model.add(resnet)
# Flatten the output of ResNet50
model.add(Flatten())
# Add a fully connected layer with 122 neurons and ReLU activation
model.add(Dense(122,activation="relu"))
# Add the output layer with 4 neurons for 4 classes and softmax activation
model.add(Dense(4,activation="softmax"))
seed_value = 777
# Compile the model with Adam optimizer, sparse categorical crossentropy loss, and accuracy metric
model.compile(optimizer="adam",loss="sparse_categorical_crossentropy",metrics=['accuracy'])
To save computation time for further processing, we will save our best-performing model, based on validation accuracy, to a Keras file.
seed_value = 777
from tensorflow.keras.callbacks import ModelCheckpoint
# Define the checkpoint callback to save the best model based on validation accuracy
checkpoint = ModelCheckpoint('best_model_all.keras', monitor='val_accuracy', save_best_only=True, mode='max', verbose=1)
# Train the model with the checkpoint callback
history = model.fit(x_data, y_data, validation_data=(x_test, y_test), epochs=20, batch_size=100, callbacks=[checkpoint])
Epoch  1/20 - accuracy: 0.3645 - loss: 6.9879 - val_accuracy: 0.3051 - val_loss: 4299.3125  (val_accuracy improved, model saved)
Epoch  2/20 - accuracy: 0.6772 - loss: 0.8099 - val_accuracy: 0.2601 - val_loss: 1.6125
Epoch  3/20 - accuracy: 0.8416 - loss: 0.4342 - val_accuracy: 0.2632 - val_loss: 2.0032
Epoch  4/20 - accuracy: 0.8979 - loss: 0.2790 - val_accuracy: 0.5431 - val_loss: 1.3973  (model saved)
Epoch  5/20 - accuracy: 0.9158 - loss: 0.2324 - val_accuracy: 0.7132 - val_loss: 0.7417  (model saved)
Epoch  6/20 - accuracy: 0.9248 - loss: 0.2226 - val_accuracy: 0.8268 - val_loss: 0.4676  (model saved)
Epoch  7/20 - accuracy: 0.9268 - loss: 0.2121 - val_accuracy: 0.8642 - val_loss: 0.3722  (model saved)
Epoch  8/20 - accuracy: 0.9478 - loss: 0.1556 - val_accuracy: 0.8558 - val_loss: 0.4426
Epoch  9/20 - accuracy: 0.9711 - loss: 0.0912 - val_accuracy: 0.8024 - val_loss: 0.7627
Epoch 10/20 - accuracy: 0.9617 - loss: 0.1014 - val_accuracy: 0.8787 - val_loss: 0.3497  (model saved)
Epoch 11/20 - accuracy: 0.9751 - loss: 0.0673 - val_accuracy: 0.8467 - val_loss: 0.8229
Epoch 12/20 - accuracy: 0.9775 - loss: 0.0678 - val_accuracy: 0.8230 - val_loss: 0.6712
Epoch 13/20 - accuracy: 0.9902 - loss: 0.0267 - val_accuracy: 0.8917 - val_loss: 0.3808  (model saved)
Epoch 14/20 - accuracy: 0.9847 - loss: 0.0444 - val_accuracy: 0.8497 - val_loss: 7.2249
Epoch 15/20 - accuracy: 0.9528 - loss: 0.2861 - val_accuracy: 0.3089 - val_loss: 24788260.0000
Epoch 16/20 - accuracy: 0.7790 - loss: 0.6845 - val_accuracy: 0.3089 - val_loss: 1647.9005
Epoch 17/20 - accuracy: 0.8955 - loss: 0.2693 - val_accuracy: 0.5019 - val_loss: 7.5326
Epoch 18/20 - accuracy: 0.9568 - loss: 0.1271 - val_accuracy: 0.8795 - val_loss: 0.3459
Epoch 19/20 - accuracy: 0.9584 - loss: 0.1162 - val_accuracy: 0.9222 - val_loss: 0.2145  (model saved)
Epoch 20/20 - accuracy: 0.9727 - loss: 0.0803 - val_accuracy: 0.8825 - val_loss: 0.3575
Best model saved to best_model_all.keras with val_accuracy 0.9222 (epoch 19).
# Load the best model saved by the checkpoint callback
best_model_all = tf.keras.models.load_model('best_model_all.keras')
# Compile the model
best_model_all.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
# Evaluate the model on the validation set
val_loss, val_accuracy = best_model_all.evaluate(x_test, y_test)
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")
41/41 ━━━━━━━━━━━━━━━━━━━━ 15s 273ms/step - accuracy: 0.9155 - loss: 0.2578 Validation Accuracy: 92.22%
Our best-performing model has achieved a validation accuracy of 92% on the test set.
Generating the confusion matrix on the training data.
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# best_model_all.predict returns probabilities for each class
predictions = best_model_all.predict(x_data)
# Convert probabilities to class labels
predicted_classes = np.argmax(predictions, axis=1)
# Compute confusion matrix
cm = confusion_matrix(y_data, predicted_classes)
# Define class names
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
# Plot confusion matrix
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
disp.plot(cmap='Blues')
plt.title('Confusion Matrix')
plt.show()
179/179 ━━━━━━━━━━━━━━━━━━━━ 52s 274ms/step
In the context of detecting tumors, the most important measure is not overall accuracy but avoiding false negatives, i.e. tumor scans classified as no tumor. The false negatives in the class no tumor amount to only 0.44%, but this is still problematic in a real-life scenario: seven brain scans of the class meningioma have been wrongly labeled as no tumor.
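The critical error rate can be read directly off the confusion matrix. A minimal sketch with hypothetical counts (the matrix below is illustrative only; in practice, plug in the `cm` computed above):

```python
import numpy as np

class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
notumor = class_names.index('notumor')

# Hypothetical confusion matrix (rows = true class, cols = predicted class).
cm = np.array([
    [1310,   10,    1,    0],   # true glioma
    [  15, 1310,    7,    7],   # true meningioma (7 scans labeled notumor)
    [   2,    3, 1452,    0],   # true pituitary
    [   0,    5,    0, 1590],   # true notumor
])

# Fraction of true-tumor scans that the model labels "notumor" (the critical error)
for i, name in enumerate(class_names):
    if i == notumor:
        continue
    fn_rate = cm[i, notumor] / cm[i].sum()
    print(f"{name}: {fn_rate:.2%} of scans misclassified as notumor")
```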
2.1.4 SHAP Analysis¶
After training the model, the next step is to develop SHAP explanations. We will create SHAP explanations for all 4 classes.
# reverse mapping to get the description instead of the Integers
label_rev_mapping = { 0: 'glioma', 1 : 'meningioma', 2 : 'pituitary' , 3 : 'notumor'}
y_decoded = np.array([label_rev_mapping[label] for label in y_data])
y_decoded
array(['notumor', 'notumor', 'notumor', ..., 'meningioma', 'notumor',
'notumor'], dtype='<U10')
# Masks out image regions with blurring or inpainting.
masker = shap.maskers.Image("inpaint_telea", x_data[0].shape)
# class names for the predictions
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
# Uses Shapley values to explain the trained model
explainer = shap.Explainer(best_model_all, masker, output_names=class_names)
Note: outputs=shap.Explanation.argsort.flip[:3] is used when computing SHAP values because we want explanations for the top 3 most probable classes for each image, i.e., the 3 classes in decreasing order of predicted probability.
shap_values = explainer(
x_data[421:422], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:24, 24.80s/it]
The predictions for our first example image are explained in the plot above. Red pixels represent positive SHAP values that increase the probability of the class, while blue pixels represent negative SHAP values that reduce the probability of the class.
The correct label:
y_decoded[421]
'meningioma'
The panel with the most red pixels corresponds to the correct label. The red pixels are also tightly clustered around the tumor, which shows that the SHAP explanations work very well for this image.
We check if the model makes the correct predictions for the image used in SHAP to ensure the explanations are meaningful and accurate. This verification step confirms that our SHAP analysis highlights the most relevant brain regions for correctly classified images.
# Select the image from x_data
test_image = x_data[421]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step Predicted class: meningioma
The model made the correct prediction!
shap_values = explainer(
x_data[623:624], max_evals=2000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:47, 47.75s/it]
In the plot above, a glioma tumor has been successfully predicted; the red pixels are also clustered quite close to the tumor in the brain scan.
y_decoded[623]
'glioma'
# Select the image from x_data
test_image = x_data[623]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 51ms/step Predicted class: glioma
shap_values = explainer(
x_data[110:111], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:30, 30.03s/it]
The SHAP explanations also work very well for the third tumor type (pituitary): the red pixels are very strong and clustered.
y_decoded[110]
'pituitary'
# Select the image from x_data
test_image = x_data[110]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step Predicted class: pituitary
shap_values = explainer(
x_data[3:4], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:30, 30.04s/it]
This example shows a brain scan without a tumor (class notumor): the red pixels are not as clustered as in the tumor scans and highlight healthy regions of the brain. The model also strongly highlights an area outside of the scan as evidence for the notumor class.
y_decoded[3]
'notumor'
# Select the image from x_data
test_image = x_data[3]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step Predicted class: notumor
shap_values = explainer(
x_data[207:211], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 5it [02:03, 31.00s/it]
The MRI images in this dataset are taken from different perspectives, which means that the orientation and angle of the scans can vary significantly. This variability adds complexity to the classification task, as the model needs to be robust enough to handle these differences. Nevertheless, the model and the SHAP explanations have already shown that it is robust and can handle MRI scans from different angles.
y_decoded[207:211]
array(['notumor', 'pituitary', 'meningioma', 'glioma'], dtype='<U10')
shap_values = explainer(
x_data[1014:1016], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 3it [01:00, 30.01s/it]
Although the model correctly predicts the tumor type in these examples, the SHAP image explanations highlight pixels that do not correspond to the actual location of the tumor. This indicates that while the model performs well, it may be learning patterns that allow it to classify images correctly without accurately identifying the relevant regions in the brain scans. This underscores the necessity of improving the model and maintaining human oversight in a real-world scenario. SHAP is very powerful in revealing such discrepancies and helps us understand the model's decision-making process, which can serve as a basis for checking how that process aligns with medical expertise.
#correct labels
y_decoded[1014:1016]
array(['meningioma', 'meningioma'], dtype='<U10')
# Select the image from x_data
test_image = x_data[1014]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 50ms/step Predicted class: meningioma
As a final step, we examine an example where the model predicted "notumor" (class 3) but the actual diagnosis was "meningioma" (class 1). This scenario represents a critical error in real-life applications, as failing to detect a tumor can have severe consequences.
# Generate predictions
predictions = best_model_all.predict(x_data)
predicted_classes = np.argmax(predictions, axis=1)
# Find examples where the model predicted class 3 but the actual class was 1
mismatch_indexes = [
    i for i, (true_label, predicted_label) in enumerate(zip(y_data, predicted_classes))
    if predicted_label == 3 and true_label == 1
]
# Print the index of the first example
example_index = mismatch_indexes[0]
print(f"Example where the model predicted notumor but the actual class was meningioma: Index {example_index}")
# Display the example image
example_image = x_data[example_index]
plt.imshow(example_image, cmap='gray')
plt.title("Predicted: 3 (notumor), Actual: 1 (meningioma)")
plt.show()
179/179 ━━━━━━━━━━━━━━━━━━━━ 50s 278ms/step Example where the model predicted notumor but the actual class was meningioma: Index 1031
shap_values = explainer(
x_data[1031:1032], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:24, 24.78s/it]
The SHAP explanations show that even though the red pixels are densely clustered around the region of interest (the tumor), the model still predicted "notumor". Interestingly, the second most likely class, which is the correct class, "meningioma" (class 1), also shows some red pixels around the tumor region. This highlights a crucial challenge in identifying tumors in MRI scans, which are black-and-white images. The model may also sometimes learn patterns from areas outside the actual image, which are uniformly black.
y_decoded[1031]
'meningioma'
# Select the image from x_data
test_image = x_data[1031]
# Reshape the image to match the input shape expected by the model
test_image = np.expand_dims(test_image, axis=0)
# Get the model's prediction
prediction = best_model_all.predict(test_image)
# Decode the prediction
predicted_class = np.argmax(prediction, axis=1)
class_names = ['glioma', 'meningioma', 'pituitary', 'notumor']
print(f"Predicted class: {class_names[predicted_class[0]]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step Predicted class: notumor
2.2 Binary Classification (No Tumor or Meningioma)¶
Our model, trained to classify MRI scans into 4 classes (three tumor types and no tumor), has already achieved a validation accuracy of 92%.
Building upon this, we now aim to refine our approach by focusing specifically on distinguishing between one specific tumor (meningioma) and no tumor cases. By retraining the ResNet50 architecture on this subset of data, we seek to further enhance the model's accuracy and efficiency in this binary classification task.
Therefore, we first need to remove the data from the other classes.
2.2.1 Data Preprocessing¶
# Define the labels to keep
labels_to_keep = ['meningioma', 'notumor']
# Create a mask to filter the data
mask = np.isin(y_decoded, labels_to_keep)
# Apply the mask to filter the data
x_data_filtered = x_data[mask]
y_data_filtered = y_data[mask]
y_data_filtered
array([3, 3, 3, ..., 1, 3, 3])
plt.imshow(x_data_filtered[96])
<matplotlib.image.AxesImage at 0x25b19aa8710>
We also need to decode the test set and get the labels.
# reverse mapping to get the description instead of the Integers
label_rev_mapping = { 0: 'glioma', 1 : 'meningioma', 2 : 'pituitary' , 3 : 'notumor'}
y_test_decoded = np.array([label_rev_mapping[label] for label in y_test])
y_test_decoded
array(['notumor', 'glioma', 'pituitary', ..., 'glioma', 'notumor',
'notumor'], dtype='<U10')
# Create a mask to filter the test data
mask_test = np.isin(y_test_decoded, labels_to_keep)
# Apply the mask to filter the test data
x_test_filtered = x_test[mask_test]
y_test_filtered = y_test[mask_test]
Mapping the data to 0 for notumor and 1 for meningioma.
#previous mapping: label_mapping = {'glioma': 0, 'meningioma': 1, 'pituitary': 2, 'notumor': 3}
##new mapping:
label_mapping2 = {1: 1, 3: 0}
##train
y_data_filtered = np.array([label_mapping2[label] for label in y_data_filtered])
##test
y_test_filtered = np.array([label_mapping2[label] for label in y_test_filtered])
plt.imshow(x_test_filtered[4])
<matplotlib.image.AxesImage at 0x2343d103690>
train_df = pd.DataFrame(y_data_filtered, columns=["labels"])
sns.countplot(x='labels', data=train_df)
<Axes: xlabel='labels', ylabel='count'>
from tensorflow.keras.utils import to_categorical
# One-hot encode y_data_filtered and y_test_filtered
y_data_filtered_one_hot = to_categorical(y_data_filtered, num_classes=2)
y_test_filtered_one_hot = to_categorical(y_test_filtered, num_classes=2)
2.2.2 Training the second model with the adapted architecture¶
seed_value = 41
tf.random.set_seed(seed_value)
model2 = Sequential()
resnet2 = tf.keras.applications.resnet50.ResNet50(include_top=False,
weights=None,
input_shape=(100,100,3),
classes=2)
model2.add(resnet2)
model2.add(Flatten())
model2.add(Dense(122,activation="relu"))
model2.add(Dense(2, activation='softmax')) # Output layer adjusted for 2 classes
model2.compile(optimizer="adam", loss="categorical_crossentropy", metrics=['accuracy'])
from tensorflow.keras.callbacks import ModelCheckpoint
# Define the checkpoint callback to save the best model based on validation accuracy
checkpoint2 = ModelCheckpoint('best_model2.keras', monitor='val_accuracy', save_best_only=True, mode='max', verbose=1)
# Train the model with the checkpoint callback
history2 = model2.fit(x_data_filtered, y_data_filtered_one_hot, validation_data=(x_test_filtered, y_test_filtered_one_hot), epochs=20, batch_size=100, callbacks=[checkpoint2])
Epoch 1/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.5379 - loss: 6.7222 Epoch 1: val_accuracy improved from -inf to 0.56962, saving model to best_model2.keras 30/30 ━━━━━━━━━━━━━━━━━━━━ 243s 6s/step - accuracy: 0.5379 - loss: 6.6075 - val_accuracy: 0.5696 - val_loss: 25097782.0000 Epoch 2/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.5488 - loss: 0.6399 Epoch 2: val_accuracy did not improve from 0.56962 30/30 ━━━━━━━━━━━━━━━━━━━━ 186s 6s/step - accuracy: 0.5486 - loss: 0.6396 - val_accuracy: 0.5696 - val_loss: 24849.8457 Epoch 3/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.5488 - loss: 0.6281 Epoch 3: val_accuracy did not improve from 0.56962 30/30 ━━━━━━━━━━━━━━━━━━━━ 188s 6s/step - accuracy: 0.5486 - loss: 0.6275 - val_accuracy: 0.5696 - val_loss: 309.0490 Epoch 4/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.6478 - loss: 0.5553 Epoch 4: val_accuracy improved from 0.56962 to 0.64557, saving model to best_model2.keras 30/30 ━━━━━━━━━━━━━━━━━━━━ 189s 6s/step - accuracy: 0.6499 - loss: 0.5546 - val_accuracy: 0.6456 - val_loss: 0.6475 Epoch 5/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.7982 - loss: 0.4950 Epoch 5: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 176s 6s/step - accuracy: 0.7986 - loss: 0.4947 - val_accuracy: 0.4979 - val_loss: 0.6858 Epoch 6/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.8324 - loss: 0.4535 Epoch 6: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 179s 6s/step - accuracy: 0.8326 - loss: 0.4533 - val_accuracy: 0.4866 - val_loss: 0.6891 Epoch 7/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.8758 - loss: 0.4034 Epoch 7: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 176s 6s/step - accuracy: 0.8760 - loss: 0.4034 - val_accuracy: 0.4655 - val_loss: 0.7051 Epoch 8/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9037 - loss: 0.3609 Epoch 8: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 
176s 6s/step - accuracy: 0.9040 - loss: 0.3607 - val_accuracy: 0.4543 - val_loss: 0.7139 Epoch 9/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9288 - loss: 0.3229 Epoch 9: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 181s 6s/step - accuracy: 0.9292 - loss: 0.3228 - val_accuracy: 0.4374 - val_loss: 0.7334 Epoch 10/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9325 - loss: 0.3110 Epoch 10: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 186s 6s/step - accuracy: 0.9329 - loss: 0.3107 - val_accuracy: 0.5091 - val_loss: 0.7029 Epoch 11/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9782 - loss: 0.2559 Epoch 11: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 182s 6s/step - accuracy: 0.9783 - loss: 0.2556 - val_accuracy: 0.4529 - val_loss: 0.7472 Epoch 12/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9879 - loss: 0.2269 Epoch 12: val_accuracy did not improve from 0.64557 30/30 ━━━━━━━━━━━━━━━━━━━━ 182s 6s/step - accuracy: 0.9878 - loss: 0.2269 - val_accuracy: 0.6076 - val_loss: 0.6232 Epoch 13/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9829 - loss: 0.2234 Epoch 13: val_accuracy improved from 0.64557 to 0.86498, saving model to best_model2.keras 30/30 ━━━━━━━━━━━━━━━━━━━━ 183s 6s/step - accuracy: 0.9828 - loss: 0.2234 - val_accuracy: 0.8650 - val_loss: 0.3720 Epoch 14/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9856 - loss: 0.2109 Epoch 14: val_accuracy did not improve from 0.86498 30/30 ━━━━━━━━━━━━━━━━━━━━ 180s 6s/step - accuracy: 0.9857 - loss: 0.2108 - val_accuracy: 0.8354 - val_loss: 0.4055 Epoch 15/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9921 - loss: 0.1912 Epoch 15: val_accuracy improved from 0.86498 to 0.91983, saving model to best_model2.keras 30/30 ━━━━━━━━━━━━━━━━━━━━ 184s 6s/step - accuracy: 0.9921 - loss: 0.1911 - val_accuracy: 0.9198 - val_loss: 0.2887 Epoch 16/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9938 - 
loss: 0.1776 Epoch 16: val_accuracy did not improve from 0.91983 30/30 ━━━━━━━━━━━━━━━━━━━━ 183s 6s/step - accuracy: 0.9937 - loss: 0.1777 - val_accuracy: 0.8579 - val_loss: 0.3599 Epoch 17/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9910 - loss: 0.1748 Epoch 17: val_accuracy did not improve from 0.91983 30/30 ━━━━━━━━━━━━━━━━━━━━ 179s 6s/step - accuracy: 0.9910 - loss: 0.1748 - val_accuracy: 0.8903 - val_loss: 0.3107 Epoch 18/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9749 - loss: 0.1936 Epoch 18: val_accuracy did not improve from 0.91983 30/30 ━━━━━━━━━━━━━━━━━━━━ 175s 6s/step - accuracy: 0.9748 - loss: 0.1936 - val_accuracy: 0.7018 - val_loss: 0.5444 Epoch 19/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9865 - loss: 0.1721 Epoch 19: val_accuracy did not improve from 0.91983 30/30 ━━━━━━━━━━━━━━━━━━━━ 180s 6s/step - accuracy: 0.9865 - loss: 0.1721 - val_accuracy: 0.8819 - val_loss: 0.4893 Epoch 20/20 30/30 ━━━━━━━━━━━━━━━━━━━━ 0s 6s/step - accuracy: 0.9948 - loss: 0.1505 Epoch 20: val_accuracy improved from 0.91983 to 0.92124, saving model to best_model2.keras 30/30 ━━━━━━━━━━━━━━━━━━━━ 181s 6s/step - accuracy: 0.9947 - loss: 0.1506 - val_accuracy: 0.9212 - val_loss: 0.3375
plt.plot(history2.history['accuracy'], label='Training Accuracy')
plt.plot(history2.history['val_accuracy'], label='Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
In the first ten epochs, the model showed clear signs of overfitting: the training accuracy consistently outperformed the validation accuracy. As training progressed, however, the two curves converged.
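The checkpoint callback already keeps the weights from the epoch with the best validation accuracy. The same idea can be applied to the recorded history to quantify the train/validation gap; a small sketch with hypothetical accuracy values standing in for `history2.history`:

```python
# Hypothetical accuracy curves (stand-ins for history2.history values)
train_acc = [0.55, 0.80, 0.90, 0.97, 0.99, 0.99]
val_acc   = [0.57, 0.50, 0.47, 0.61, 0.92, 0.92]

# Epoch with the best validation accuracy (what ModelCheckpoint keeps)
best_epoch = max(range(len(val_acc)), key=val_acc.__getitem__)

# Generalization gap per epoch: large gaps signal overfitting
gaps = [t - v for t, v in zip(train_acc, val_acc)]

print(best_epoch, round(gaps[best_epoch], 2))  # → 4 0.07
```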
# Loading the trained model
best_model2 = tf.keras.models.load_model('best_model2.keras')
# Compile the model (integer labels, hence the sparse loss)
best_model2.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
# Evaluate the model on the validation set
val_loss, val_accuracy = best_model2.evaluate(x_test_filtered, y_test_filtered)
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")
23/23 ━━━━━━━━━━━━━━━━━━━━ 11s 287ms/step - accuracy: 0.9209 - loss: 0.4258 Validation Accuracy: 92.12%
The best checkpointed model, "best_model2", achieves an accuracy of ~92% in classifying between notumor and meningioma.
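Accuracy alone hides which kinds of errors the model makes; in a medical setting a missed tumor (false negative) is usually costlier than a false alarm. A minimal sketch of the per-class counts, using toy labels rather than the real test set:

```python
# Toy binary labels (0 = notumor, 1 = meningioma); stand-ins for real predictions
y_true = [0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 1, 1, 1, 0, 0, 1, 0]

tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)

sensitivity = tp / (tp + fn)  # share of tumors that were caught
specificity = tn / (tn + fp)  # share of healthy scans correctly cleared
print(tp, tn, fp, fn, sensitivity, specificity)  # → 3 3 1 1 0.75 0.75
```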
2.2.3 SHAP Analysis¶
# Class names for the binary model (0: notumor, 1: meningioma)
class_names = ['notumor', 'meningioma']
masker = shap.maskers.Image("inpaint_telea", x_data_filtered[0].shape)
explainer = shap.Explainer(best_model2, masker, output_names= class_names)
shap_values = explainer(
x_data_filtered[2063:2064], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:2]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:24, 24.04s/it]
The model performs well in terms of prediction accuracy in the binary classification task (meningioma or notumor). However, the SHAP explanations do not provide as clear insights as they did for the multi-class classification. For binary classification, the SHAP explanations consistently showed opposite pixel contributions for each class, with red pixels in one class appearing as blue in the other. This mirroring effect does not reveal as much information as it does in multi-class settings.
Additionally, the highlighted pixels in the SHAP explanations are not clustered around the correct region of the brain. Therefore, in our example, SHAP did not work as well for binary classification, highlighting the necessity of further refinement in interpretability methods for binary medical image classification.
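The mirroring effect is in fact expected for a two-class softmax: the two output probabilities always sum to one, so whatever a pixel contributes to one class it must subtract from the other. A pure-Python sketch with hypothetical logit values:

```python
import math

def softmax2(a, b):
    """Softmax over two logits; the two probabilities sum to one."""
    ea, eb = math.exp(a), math.exp(b)
    return ea / (ea + eb), eb / (ea + eb)

# Baseline and perturbed logits for a two-class model
# (hypothetical values standing in for a pixel's masked/unmasked effect)
p0_base, p1_base = softmax2(0.3, -0.1)
p0_pert, p1_pert = softmax2(1.2, -0.1)

delta0 = p0_pert - p0_base
delta1 = p1_pert - p1_base

# The change in one class probability exactly mirrors the other
print(abs(delta0 + delta1) < 1e-9)  # → True
```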
y_data_filtered[2063]
1
# Select the images from x_data
test_images = x_data_filtered[2063]
# Reshape the images to match the input shape expected by the model
test_images = np.expand_dims(test_images, axis=0)
# Get the predictions from the binary model
predictions = best_model2.predict(test_images)
# Decode the predictions
predicted_classes = np.argmax(predictions, axis=1)
class_names = ['notumor', 'meningioma']
# Print the predicted classes for both images
for i, predicted_class in enumerate(predicted_classes):
print(f"Predicted class for image: {class_names[predicted_class]}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 87ms/step Predicted class for image 1014: meningioma
2.3 Three-Class Classification¶
As a last step, we adapt the model to differentiate only between the three tumor types.
# Get the indices of the data where the label is not 3
tumor_indices_data = np.where(y_data != 3)[0]
tumor_indices_test = np.where(y_test != 3)[0]
# Filter the data and labels
x_data_tumor = x_data[tumor_indices_data]
y_data_tumor = y_data[tumor_indices_data]
x_test_tumor = x_test[tumor_indices_test]
y_test_tumor = y_test[tumor_indices_test]
2.3.1 Model Training¶
seed_value = 777
tf.random.set_seed(seed_value)
# Define the model architecture using Sequential
model3 = Sequential()
# Add ResNet50 as a base model (excluding the top layer)
resnet3 = ResNet50(include_top=False,
weights=None, # Weights will be retrained
input_shape=(100, 100, 3),
classes=3)
model3.add(resnet3)
# Flatten the output of ResNet50
model3.add(Flatten())
# Add a fully connected layer with 122 neurons and ReLU activation
model3.add(Dense(122, activation="relu"))
# Add the output layer with 3 neurons for 3 classes and softmax activation
model3.add(Dense(3, activation="softmax"))
# Compile the model
model3.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
from tensorflow.keras.callbacks import ModelCheckpoint
# Define the checkpoint callback to save the best model based on validation accuracy
checkpoint3 = ModelCheckpoint('best_model3.keras', monitor='val_accuracy', save_best_only=True, mode='max', verbose=1)
# Train the model with the checkpoint callback
history3 = model3.fit(x_data_tumor, y_data_tumor, validation_data=(x_test_tumor, y_test_tumor), epochs=20, batch_size=100, callbacks=[checkpoint3])
Epoch 1/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.3448 - loss: 6.6308 Epoch 1: val_accuracy improved from -inf to 0.33775, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 222s 4s/step - accuracy: 0.3452 - loss: 6.5726 - val_accuracy: 0.3377 - val_loss: 1081583.1250 Epoch 2/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.4122 - loss: 1.8803 Epoch 2: val_accuracy improved from 0.33775 to 0.33885, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 184s 4s/step - accuracy: 0.4128 - loss: 1.8815 - val_accuracy: 0.3389 - val_loss: 2822.8599 Epoch 3/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.5257 - loss: 2.2058 Epoch 3: val_accuracy improved from 0.33885 to 0.35099, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 188s 4s/step - accuracy: 0.5275 - loss: 2.2048 - val_accuracy: 0.3510 - val_loss: 1640.3108 Epoch 4/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.7284 - loss: 1.9577 Epoch 4: val_accuracy improved from 0.35099 to 0.44260, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 188s 4s/step - accuracy: 0.7284 - loss: 1.9508 - val_accuracy: 0.4426 - val_loss: 234.2392 Epoch 5/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.7778 - loss: 2.2307 Epoch 5: val_accuracy improved from 0.44260 to 0.47903, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 191s 5s/step - accuracy: 0.7781 - loss: 2.2120 - val_accuracy: 0.4790 - val_loss: 27.8041 Epoch 6/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.8267 - loss: 1.1908 Epoch 6: val_accuracy improved from 0.47903 to 0.69095, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 188s 4s/step - accuracy: 0.8270 - loss: 1.1838 - val_accuracy: 0.6909 - val_loss: 0.8209 Epoch 7/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.8731 - loss: 1.3664 Epoch 7: val_accuracy improved from 0.69095 to 0.77594, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 188s 4s/step - accuracy: 0.8733 - loss: 1.3600 - 
val_accuracy: 0.7759 - val_loss: 0.8794 Epoch 8/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9116 - loss: 0.8386 Epoch 8: val_accuracy improved from 0.77594 to 0.82781, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 194s 5s/step - accuracy: 0.9117 - loss: 0.8370 - val_accuracy: 0.8278 - val_loss: 0.4264 Epoch 9/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9499 - loss: 0.4189 Epoch 9: val_accuracy did not improve from 0.82781 42/42 ━━━━━━━━━━━━━━━━━━━━ 187s 4s/step - accuracy: 0.9498 - loss: 0.4177 - val_accuracy: 0.8190 - val_loss: 0.4662 Epoch 10/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9711 - loss: 0.3296 Epoch 10: val_accuracy improved from 0.82781 to 0.86203, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 181s 4s/step - accuracy: 0.9710 - loss: 0.3277 - val_accuracy: 0.8620 - val_loss: 0.4749 Epoch 11/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9353 - loss: 0.6065 Epoch 11: val_accuracy did not improve from 0.86203 42/42 ━━━━━━━━━━━━━━━━━━━━ 179s 4s/step - accuracy: 0.9353 - loss: 0.6046 - val_accuracy: 0.8455 - val_loss: 0.4104 Epoch 12/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9529 - loss: 0.2172 Epoch 12: val_accuracy improved from 0.86203 to 0.87196, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 182s 4s/step - accuracy: 0.9532 - loss: 0.2166 - val_accuracy: 0.8720 - val_loss: 0.3331 Epoch 13/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9670 - loss: 0.2396 Epoch 13: val_accuracy did not improve from 0.87196 42/42 ━━━━━━━━━━━━━━━━━━━━ 179s 4s/step - accuracy: 0.9672 - loss: 0.2384 - val_accuracy: 0.8675 - val_loss: 0.4399 Epoch 14/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9834 - loss: 0.1245 Epoch 14: val_accuracy did not improve from 0.87196 42/42 ━━━━━━━━━━━━━━━━━━━━ 176s 4s/step - accuracy: 0.9833 - loss: 0.1246 - val_accuracy: 0.8687 - val_loss: 0.4371 Epoch 15/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9383 - loss: 
0.2643 Epoch 15: val_accuracy did not improve from 0.87196 42/42 ━━━━━━━━━━━━━━━━━━━━ 170s 4s/step - accuracy: 0.9380 - loss: 0.2658 - val_accuracy: 0.5099 - val_loss: 52.3652 Epoch 16/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 25s/step - accuracy: 0.9495 - loss: 0.1709 Epoch 16: val_accuracy did not improve from 0.87196 42/42 ━━━━━━━━━━━━━━━━━━━━ 1032s 25s/step - accuracy: 0.9497 - loss: 0.1704 - val_accuracy: 0.8035 - val_loss: 0.9316 Epoch 17/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9750 - loss: 0.0869 Epoch 17: val_accuracy improved from 0.87196 to 0.90508, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 174s 4s/step - accuracy: 0.9750 - loss: 0.0867 - val_accuracy: 0.9051 - val_loss: 0.2839 Epoch 18/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9791 - loss: 0.1250 Epoch 18: val_accuracy improved from 0.90508 to 0.91170, saving model to best_model3.keras 42/42 ━━━━━━━━━━━━━━━━━━━━ 183s 4s/step - accuracy: 0.9791 - loss: 0.1243 - val_accuracy: 0.9117 - val_loss: 0.2508 Epoch 19/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9845 - loss: 0.0652 Epoch 19: val_accuracy did not improve from 0.91170 42/42 ━━━━━━━━━━━━━━━━━━━━ 180s 4s/step - accuracy: 0.9845 - loss: 0.0650 - val_accuracy: 0.8598 - val_loss: 0.5471 Epoch 20/20 42/42 ━━━━━━━━━━━━━━━━━━━━ 0s 4s/step - accuracy: 0.9856 - loss: 0.0423 Epoch 20: val_accuracy did not improve from 0.91170 42/42 ━━━━━━━━━━━━━━━━━━━━ 172s 4s/step - accuracy: 0.9856 - loss: 0.0422 - val_accuracy: 0.8753 - val_loss: 0.5665
best_model3 = tf.keras.models.load_model('best_model3.keras')
# Compile the model
best_model3.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=['accuracy'])
# Evaluate the model on the validation set
val_loss, val_accuracy = best_model3.evaluate(x_test_tumor, y_test_tumor)
print(f"Validation Accuracy: {val_accuracy * 100:.2f}%")
29/29 ━━━━━━━━━━━━━━━━━━━━ 14s 293ms/step - accuracy: 0.9065 - loss: 0.2430 Validation Accuracy: 91.17%
The accuracy did not improve after the "notumor" class was removed. We will also look at the SHAP explanations.
masker = shap.maskers.Image("inpaint_telea", x_data_tumor[0].shape)
class_names = ['glioma', 'meningioma', 'pituitary']
explainer = shap.Explainer(best_model3, masker, output_names= class_names)
2.3.2 SHAP Analysis¶
shap_values = explainer(
x_data_tumor[288:289], max_evals=1000, batch_size=50, outputs=shap.Explanation.argsort.flip[:3]
)
shap.image_plot(shap_values)
PartitionExplainer explainer: 2it [00:22, 22.52s/it]
print(f"Correct Label: {y_data_tumor[281]}")
Correct Label: 2
There is no significant difference between the 4-class model (including "notumor") and the 3-class model (excluding "notumor") in terms of model performance and SHAP explanations.
3. Heart Disease¶
3.1 Background¶
According to the World Health Organization (WHO), cardiovascular diseases (CVDs) are the leading cause of death globally, taking an estimated 17.9 million lives each year. CVDs are a group of disorders of the heart and blood vessels and include coronary heart disease, cerebrovascular disease, rheumatic heart disease, and other conditions. More than four out of five CVD deaths are due to heart attacks and strokes, and one-third of these deaths occur prematurely in people under 70 years of age.
People with cardiovascular disease or who are at high cardiovascular risk (due to the presence of one or more risk factors such as hypertension, diabetes, hyperlipidemia, or already established disease) need early detection and management, wherein a machine learning model can be of great help.
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import precision_recall_fscore_support
Importing the heart disease dataset:
import pandas as pd
df = pd.read_csv("insert_your_path_and_file_name.csv")  # placeholder: point this at the heart disease CSV
print(df)
df.info()
df
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG \
0 40 M ATA 140 289 0 Normal
1 49 F NAP 160 180 0 Normal
2 37 M ATA 130 283 0 ST
3 48 F ASY 138 214 0 Normal
4 54 M NAP 150 195 0 Normal
.. ... .. ... ... ... ... ...
913 45 M TA 110 264 0 Normal
914 68 M ASY 144 193 1 Normal
915 57 M ASY 130 131 0 Normal
916 57 F ATA 130 236 0 LVH
917 38 M NAP 138 175 0 Normal
MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
0 172 N 0.0 Up 0
1 156 N 1.0 Flat 1
2 98 N 0.0 Up 0
3 108 Y 1.5 Flat 1
4 122 N 0.0 Up 0
.. ... ... ... ... ...
913 132 N 1.2 Flat 1
914 141 N 3.4 Flat 1
915 115 Y 1.2 Flat 1
916 174 N 0.0 Flat 1
917 173 N 0.0 Up 0
[918 rows x 12 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 918 entries, 0 to 917
Data columns (total 12 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Age 918 non-null int64
1 Sex 918 non-null object
2 ChestPainType 918 non-null object
3 RestingBP 918 non-null int64
4 Cholesterol 918 non-null int64
5 FastingBS 918 non-null int64
6 RestingECG 918 non-null object
7 MaxHR 918 non-null int64
8 ExerciseAngina 918 non-null object
9 Oldpeak 918 non-null float64
10 ST_Slope 918 non-null object
11 HeartDisease 918 non-null int64
dtypes: float64(1), int64(6), object(5)
memory usage: 86.2+ KB
| Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | M | ATA | 140 | 289 | 0 | Normal | 172 | N | 0.0 | Up | 0 |
| 1 | 49 | F | NAP | 160 | 180 | 0 | Normal | 156 | N | 1.0 | Flat | 1 |
| 2 | 37 | M | ATA | 130 | 283 | 0 | ST | 98 | N | 0.0 | Up | 0 |
| 3 | 48 | F | ASY | 138 | 214 | 0 | Normal | 108 | Y | 1.5 | Flat | 1 |
| 4 | 54 | M | NAP | 150 | 195 | 0 | Normal | 122 | N | 0.0 | Up | 0 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 913 | 45 | M | TA | 110 | 264 | 0 | Normal | 132 | N | 1.2 | Flat | 1 |
| 914 | 68 | M | ASY | 144 | 193 | 1 | Normal | 141 | N | 3.4 | Flat | 1 |
| 915 | 57 | M | ASY | 130 | 131 | 0 | Normal | 115 | Y | 1.2 | Flat | 1 |
| 916 | 57 | F | ATA | 130 | 236 | 0 | LVH | 174 | N | 0.0 | Flat | 1 |
| 917 | 38 | M | NAP | 138 | 175 | 0 | Normal | 173 | N | 0.0 | Up | 0 |
918 rows × 12 columns
Checking for duplicates:
df.duplicated().sum()
0
Checking for missing values:
nan_count = df.isna().sum()
print(nan_count)
Age 0 Sex 0 ChestPainType 0 RestingBP 0 Cholesterol 0 FastingBS 0 RestingECG 0 MaxHR 0 ExerciseAngina 0 Oldpeak 0 ST_Slope 0 HeartDisease 0 dtype: int64
Unique values:
# show unique values
df.nunique()
Age 50 Sex 2 ChestPainType 4 RestingBP 67 Cholesterol 222 FastingBS 2 RestingECG 3 MaxHR 119 ExerciseAngina 2 Oldpeak 53 ST_Slope 3 HeartDisease 2 dtype: int64
This gave us a first overview of the dataset. There are 12 attributes/features in total, including the target variable HeartDisease, and the data contains no missing values or duplicates. We also saw the datatype of each attribute.
df.head()
| Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | M | ATA | 140 | 289 | 0 | Normal | 172 | N | 0.0 | Up | 0 |
| 1 | 49 | F | NAP | 160 | 180 | 0 | Normal | 156 | N | 1.0 | Flat | 1 |
| 2 | 37 | M | ATA | 130 | 283 | 0 | ST | 98 | N | 0.0 | Up | 0 |
| 3 | 48 | F | ASY | 138 | 214 | 0 | Normal | 108 | Y | 1.5 | Flat | 1 |
| 4 | 54 | M | NAP | 150 | 195 | 0 | Normal | 122 | N | 0.0 | Up | 0 |
3.2 Features:¶
Here we can see the different features/attributes from which heart disease is predicted. Each row represents a different patient, while each feature captures a particular health parameter for that patient.
The different features are described as follows:
1.) Age: age of the patient in years.
2.) Sex: sex of the patient (M for male and F for female)
3.) ChestPainType: chest pain type (TA: Typical Angina, ATA: Atypical Angina, NAP: Non-Anginal Pain, ASY: Asymptomatic)
4.) RestingBP: resting systolic blood pressure – the pressure when your heart pushes blood out around your body while resting (mm Hg)
5.) Cholesterol: serum cholesterol (mg/dl)
6.) FastingBS: fasting blood sugar (1: if FastingBS > 120 mg/dl, 0: otherwise)
7.) RestingECG: resting electrocardiogram results (Normal: Normal, ST: having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV), LVH: showing probable or definite left ventricular hypertrophy by Estes' criteria)
8.) MaxHR: maximum heart rate achieved (numeric value between 60 and 202)
9.) ExerciseAngina: exercise-induced angina (Y: Yes, N: No)
10.) Oldpeak: ST depression induced by exercise relative to rest, measured in mm (ST segment depression < 0.5 mm is accepted in all leads; >= 0.5 mm is considered pathological)
11.) ST_Slope: the slope of the peak exercise ST segment (Up: upsloping, Flat: flat, Down: downsloping)
Target attribute: HeartDisease: output class/ target values (1: heart disease, 0: Normal)
3.3 Exploratory Data Analysis¶
First, let's visualize all numerical and categorical features in the dataset, split by heart disease and normal results:
Categorical = df.select_dtypes(include=['object'])
Numerical = df.select_dtypes(include=['int64', 'float64'])
print('Categorical features:\n', Categorical)
print('Numerical features:\n', Numerical)
Categorical features:
Sex ChestPainType RestingECG ExerciseAngina ST_Slope
0 M ATA Normal N Up
1 F NAP Normal N Flat
2 M ATA ST N Up
3 F ASY Normal Y Flat
4 M NAP Normal N Up
.. .. ... ... ... ...
913 M TA Normal N Flat
914 M ASY Normal N Flat
915 M ASY Normal Y Flat
916 F ATA LVH N Flat
917 M NAP Normal N Up
[918 rows x 5 columns]
Numerical features:
Age RestingBP Cholesterol FastingBS MaxHR Oldpeak HeartDisease
0 40 140 289 0 172 0.0 0
1 49 160 180 0 156 1.0 1
2 37 130 283 0 98 0.0 0
3 48 138 214 0 108 1.5 1
4 54 150 195 0 122 0.0 0
.. ... ... ... ... ... ... ...
913 45 110 264 0 132 1.2 1
914 68 144 193 1 141 3.4 1
915 57 130 131 0 115 1.2 1
916 57 130 236 0 174 0.0 1
917 38 138 175 0 173 0.0 0
[918 rows x 7 columns]
# plotting numerical features against the target
for i in Numerical:
plt.figure(figsize=(10,5))
sns.countplot(x=i, data=df, hue='HeartDisease')
plt.legend(['Normal', 'Heart Disease'])
plt.title(i)
plt.show()
# plotting categorical features against the target
for i in Categorical:
plt.figure(figsize=(10,5))
sns.countplot(x=i, data=df, hue='HeartDisease', edgecolor='black')
plt.legend(['Normal', 'Heart Disease'])
plt.title(i)
plt.show()
#pairplot using target HeartDisease Column
sns.pairplot(df, hue='HeartDisease')
plt.show()
Taking a closer look at Age regarding Heart Disease and no Heart Disease:
sns.histplot(df['Age'][df['HeartDisease'] == 1], kde=True, color='red', label='Heart Disease')
sns.histplot(df['Age'][df['HeartDisease'] == 0], kde=True, color='green', label='Normal')
plt.legend()
<matplotlib.legend.Legend at 0x7cc11b393460>
df.describe().round(1)
| Age | RestingBP | Cholesterol | FastingBS | MaxHR | Oldpeak | HeartDisease | |
|---|---|---|---|---|---|---|---|
| count | 918.0 | 918.0 | 918.0 | 918.0 | 918.0 | 918.0 | 918.0 |
| mean | 53.5 | 132.4 | 198.8 | 0.2 | 136.8 | 0.9 | 0.6 |
| std | 9.4 | 18.5 | 109.4 | 0.4 | 25.5 | 1.1 | 0.5 |
| min | 28.0 | 0.0 | 0.0 | 0.0 | 60.0 | -2.6 | 0.0 |
| 25% | 47.0 | 120.0 | 173.2 | 0.0 | 120.0 | 0.0 | 0.0 |
| 50% | 54.0 | 130.0 | 223.0 | 0.0 | 138.0 | 0.6 | 1.0 |
| 75% | 60.0 | 140.0 | 267.0 | 0.0 | 156.0 | 1.5 | 1.0 |
| max | 77.0 | 200.0 | 603.0 | 1.0 | 202.0 | 6.2 | 1.0 |
Interestingly, we can see here that the minimum value for RestingBP (resting blood pressure) and Cholesterol is 0, meaning some patients have these values recorded as 0, which is physiologically impossible.
In general, there are two ways to deal with this issue.
- Impute with median or mean: This may be suitable if the number of 0 values is small.
- Remove the rows: If the number of rows with 0 values is very small, removing them might be reasonable.
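Both options can be sketched on a toy column (illustrative values only, not the real data):

```python
# Toy cholesterol readings; 0 marks a missing measurement, as in the dataset
values = [289, 180, 0, 214, 0, 195]

nonzero = [v for v in values if v != 0]
nonzero_mean = sum(nonzero) / len(nonzero)

# Option 1: impute zeros with the mean of the non-zero values
mean_imputed = [v if v != 0 else nonzero_mean for v in values]

# Option 2: simply drop the affected rows
dropped = nonzero

print(round(nonzero_mean, 1))  # → 219.5
```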
First, we check the number of rows where RestingBP and/or Cholesterol is 0.
bp_zero_count = df[df['RestingBP'] == 0].shape[0]
print(f'Number of rows where RestingBP is 0: {bp_zero_count}')
cholesterol_zero_count = df[df['Cholesterol'] == 0].shape[0]
print(f'Number of rows where Cholesterol is 0: {cholesterol_zero_count}')
Number of rows where RestingBP is 0: 1 Number of rows where Cholesterol is 0: 172
bp_zero_rows = df[df['RestingBP'] == 0]
print(bp_zero_rows)
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG \
449 55 M NAP 0 0 0 Normal
MaxHR ExerciseAngina Oldpeak ST_Slope HeartDisease
449 155 N 1.5 Flat 1
# Remove the row where RestingBP is 0 (copy to avoid SettingWithCopyWarning in later assignments)
df = df[df['RestingBP'] != 0].copy()
# Verify removal
print(df.shape) # Check the new number of rows and columns
(917, 12)
Replace the zero values with the mean of the non-zero values in the Cholesterol column.
cholesterol_mean = df[df['Cholesterol'] != 0]['Cholesterol'].mean()
# Impute the zero values with the mean of the non-zero values
df['Cholesterol'] = df['Cholesterol'].replace(0, cholesterol_mean)
# Verify imputation and new feature
print(df[['Cholesterol']].describe())
print(df[['RestingBP']].describe())
Cholesterol
count 917.000000
mean 244.635389
std 53.347125
min 85.000000
25% 214.000000
50% 244.635389
75% 267.000000
max 603.000000
RestingBP
count 917.000000
mean 132.540894
std 17.999749
min 80.000000
25% 120.000000
50% 130.000000
75% 140.000000
max 200.000000
df.describe().round(1)
| Age | RestingBP | Cholesterol | FastingBS | MaxHR | Oldpeak | HeartDisease | |
|---|---|---|---|---|---|---|---|
| count | 917.0 | 917.0 | 917.0 | 917.0 | 917.0 | 917.0 | 917.0 |
| mean | 53.5 | 132.5 | 244.6 | 0.2 | 136.8 | 0.9 | 0.6 |
| std | 9.4 | 18.0 | 53.3 | 0.4 | 25.5 | 1.1 | 0.5 |
| min | 28.0 | 80.0 | 85.0 | 0.0 | 60.0 | -2.6 | 0.0 |
| 25% | 47.0 | 120.0 | 214.0 | 0.0 | 120.0 | 0.0 | 0.0 |
| 50% | 54.0 | 130.0 | 244.6 | 0.0 | 138.0 | 0.6 | 1.0 |
| 75% | 60.0 | 140.0 | 267.0 | 0.0 | 156.0 | 1.5 | 1.0 |
| max | 77.0 | 200.0 | 603.0 | 1.0 | 202.0 | 6.2 | 1.0 |
Next, the categorical variables are encoded with integer/label encoding:
These variables are:
Sex, ChestPainType, RestingECG, ExerciseAngina, ST_Slope
from sklearn.preprocessing import LabelEncoder
categorical_columns = ['Sex', 'ChestPainType', 'RestingECG', 'ExerciseAngina', 'ST_Slope']
# Apply label encoding to each categorical column
for col in categorical_columns:
df[col] = LabelEncoder().fit_transform(df[col])
df.head()
| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 40 | 1 | 1 | 140 | 289.0 | 0 | 1 | 172 | 0 | 0.0 | 2 | 0 |
| 1 | 49 | 0 | 2 | 160 | 180.0 | 0 | 1 | 156 | 0 | 1.0 | 1 | 1 |
| 2 | 37 | 1 | 1 | 130 | 283.0 | 0 | 2 | 98 | 0 | 0.0 | 2 | 0 |
| 3 | 48 | 0 | 0 | 138 | 214.0 | 0 | 1 | 108 | 1 | 1.5 | 1 | 1 |
| 4 | 54 | 1 | 2 | 150 | 195.0 | 0 | 1 | 122 | 0 | 0.0 | 2 | 0 |
df.hist(figsize=(20,20));
Next, we visualize the correlation matrix to gain a deeper insight into the dataset. No pair of input features is strongly correlated (the largest magnitude among inputs is about -0.50 between Oldpeak and ST_Slope), so multicollinearity is not a concern and we keep all features in the model.
df.corr()
| | Age | Sex | ChestPainType | RestingBP | Cholesterol | FastingBS | RestingECG | MaxHR | ExerciseAngina | Oldpeak | ST_Slope | HeartDisease |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Age | 1.000000 | 0.055670 | -0.077439 | 0.263084 | 0.053374 | 0.198170 | -0.007487 | -0.382280 | 0.216017 | 0.258563 | -0.268217 | 0.282012 |
| Sex | 0.055670 | 1.000000 | -0.127408 | 0.009427 | -0.101721 | 0.120424 | 0.071553 | -0.189668 | 0.191226 | 0.105444 | -0.150408 | 0.305118 |
| ChestPainType | -0.077439 | -0.127408 | 1.000000 | -0.011020 | -0.092763 | -0.072461 | -0.072625 | 0.288465 | -0.354027 | -0.178367 | 0.214582 | -0.388592 |
| RestingBP | 0.263084 | 0.009427 | -0.011020 | 1.000000 | 0.083076 | 0.067823 | 0.023455 | -0.109693 | 0.153064 | 0.174252 | -0.082155 | 0.117990 |
| Cholesterol | 0.053374 | -0.101721 | -0.092763 | 0.083076 | 1.000000 | 0.043015 | -0.063612 | -0.017244 | 0.077578 | 0.053039 | -0.069737 | 0.094113 |
| FastingBS | 0.198170 | 0.120424 | -0.072461 | 0.067823 | 0.043015 | 1.000000 | 0.087074 | -0.131067 | 0.059988 | 0.053062 | -0.176196 | 0.267994 |
| RestingECG | -0.007487 | 0.071553 | -0.072625 | 0.023455 | -0.063612 | 0.087074 | 1.000000 | -0.179339 | 0.077545 | -0.020452 | -0.006768 | 0.057393 |
| MaxHR | -0.382280 | -0.189668 | 0.288465 | -0.109693 | -0.017244 | -0.131067 | -0.179339 | 1.000000 | -0.370023 | -0.161213 | 0.344047 | -0.401410 |
| ExerciseAngina | 0.216017 | 0.191226 | -0.354027 | 0.153064 | 0.077578 | 0.059988 | 0.077545 | -0.370023 | 1.000000 | 0.409494 | -0.429483 | 0.495490 |
| Oldpeak | 0.258563 | 0.105444 | -0.178367 | 0.174252 | 0.053039 | 0.053062 | -0.020452 | -0.161213 | 0.409494 | 1.000000 | -0.501735 | 0.403638 |
| ST_Slope | -0.268217 | -0.150408 | 0.214582 | -0.082155 | -0.069737 | -0.176196 | -0.006768 | 0.344047 | -0.429483 | -0.501735 | 1.000000 | -0.558541 |
| HeartDisease | 0.282012 | 0.305118 | -0.388592 | 0.117990 | 0.094113 | 0.267994 | 0.057393 | -0.401410 | 0.495490 | 0.403638 | -0.558541 | 1.000000 |
plt.figure(figsize=(12, 8))
sns.heatmap(df.corr(), annot=True, cmap='coolwarm')
plt.show()
Plotting the number of observations per target value reveals an uneven distribution. This class imbalance can bias the model in favor of the majority class (heart disease) and may hurt performance on the minority class. We address this issue in the next step, the data preprocessing.
plt.figure(figsize=(8, 6))
ax = sns.barplot(x=df.HeartDisease.value_counts().index, y=df.HeartDisease.value_counts(), hue=df.HeartDisease.value_counts().index, palette="Set2", legend=False)
for p in ax.patches:
height = p.get_height()
ax.annotate(
f"{height}",
(p.get_x() + p.get_width() / 2, height),
ha="center",
va="center",
xytext=(0, 7),
textcoords="offset points",
)
plt.title("Number of Observations of the Target Value")
plt.show()
3.4 Data Preprocessing¶
- Split the dataset in X (input data) and Y (target data)
- Split the data in training and test data
- Create balanced classes for training (SMOTE)
- Standardize the features
- Split the dataset in X (input data) and Y (target data)
X = df.iloc[:,:-1]
Y = pd.get_dummies(df.iloc[:,-1])
X.shape
(917, 11)
Y.shape
(917, 2)
- Split the data in training and test data
X_tr, X_te, ytr, yte = train_test_split(X, Y, test_size=0.2, shuffle=True, random_state=1)
- Create balanced classes for training (SMOTE)
X_tr.value_counts()
Age Sex ChestPainType RestingBP Cholesterol FastingBS RestingECG MaxHR ExerciseAngina Oldpeak ST_Slope
28 1 1 130 132.0 0 0 185 0 0.0 2 1
58 1 0 128 259.0 0 0 130 1 3.0 1 1
130 263.0 0 1 140 1 2.0 1 1
132 458.0 1 1 69 0 1.0 0 1
135 222.0 0 1 100 0 0.0 2 1
..
50 1 0 145 264.0 0 1 150 0 0.0 1 1
2 129 196.0 0 1 163 0 0.0 2 1
140 233.0 0 1 163 0 0.6 1 1
51 0 0 114 258.0 1 0 96 0 1.0 2 1
77 1 0 124 171.0 0 2 110 1 2.0 2 1
Name: count, Length: 733, dtype: int64
ytr.value_counts()
0      1
False  True     407
True   False    326
Name: count, dtype: int64
from imblearn.over_sampling import SMOTE
# Convert ytr back to single column of class labels
ytr_single_column = ytr.idxmax(axis=1)
# Apply SMOTE
smote = SMOTE(random_state=1)
X_tr, ytr_resampled_single_column = smote.fit_resample(X_tr, ytr_single_column)
# Convert ytr_resampled back to one-hot encoding
ytr= pd.get_dummies(ytr_resampled_single_column)
print(ytr.value_counts())
0      1
False  True     407
True   False    407
Name: count, dtype: int64
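SMOTE does not simply duplicate minority-class samples: each synthetic sample is placed on the line segment between a minority sample and one of its nearest minority-class neighbors. A minimal numpy sketch of that interpolation step (illustrative only, not imblearn's actual implementation; feature values are made up):

```python
import numpy as np

rng = np.random.default_rng(1)

def smote_like_sample(x, neighbor, rng):
    """Create a synthetic point on the segment between x and a minority-class neighbor."""
    gap = rng.uniform(0.0, 1.0)
    return x + gap * (neighbor - x)

x = np.array([50.0, 130.0])         # e.g. (Age, RestingBP) of a minority sample
neighbor = np.array([58.0, 140.0])  # a nearby sample of the same class
synthetic = smote_like_sample(x, neighbor, rng)

# The synthetic sample lies between the two originals, feature-wise
assert np.all(synthetic >= np.minimum(x, neighbor))
assert np.all(synthetic <= np.maximum(x, neighbor))
```

This is also why SMOTE is applied only to the training split: synthetic points derived from test samples would leak test information into training.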
- Standardize the features
Standardization ensures that all features contribute equally to the model training by bringing them to a similar scale. This prevents features with larger ranges from dominating the learning process. It also improves the model performance as many machine learning algorithms perform better or converge faster when features are on a similar scale and centered around zero.
from sklearn.preprocessing import StandardScaler
# Initialize StandardScaler
scaler = StandardScaler()
# Fit scaler on training data and transform both training and testing data
X_tr = scaler.fit_transform(X_tr)
X_te = scaler.transform(X_te)
print("Scaled Training Data:")
print(X_tr[:5].round(2))
Scaled Training Data:
[[ 0.84  0.57 -0.84  0.44 -0.66 -0.52 -1.54  0.01  1.32  1.    1.04]
 [-0.14 -1.77  1.28 -0.42  0.5  -0.52  0.05  0.05 -0.76 -0.8   1.04]
 [ 1.27 -1.77 -0.84  1.02 -0.34 -0.52 -1.54 -0.96 -0.76  0.15 -0.61]
 [-0.36  0.57 -0.84  0.44 -2.06 -0.52  0.05 -0.11 -0.76 -0.8   1.04]
 [-0.91  0.57 -0.84 -0.13 -0.45 -0.52  1.64 -0.32  1.32  0.15 -0.61]]
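The scaler is fitted on the training data only and merely applied to the test data; fitting it on the test set would leak test-set statistics into preprocessing. A small standalone check with toy numbers (not our dataset) that `fit_transform` centers the training data while the test data is transformed with the training statistics:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

train = np.array([[1.0], [2.0], [3.0]])
test = np.array([[10.0]])

scaler = StandardScaler()
train_scaled = scaler.fit_transform(train)  # statistics come from train only
test_scaled = scaler.transform(test)        # train statistics reused, no refit

# Training data is centered to mean 0
assert abs(train_scaled.mean()) < 1e-9
# The test point is scaled with the TRAIN mean/std: (10 - mean(train)) / std(train)
expected = (10.0 - train.mean()) / train.std()
assert abs(test_scaled[0, 0] - expected) < 1e-9
```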
3.5 Model Training: ANN with one output¶
The network ends in a two-unit softmax: the first probability leans towards 0 = no heart disease, the second towards 1 = heart disease. An instance is classified as heart disease when that probability exceeds 0.5 (equivalently, the argmax of the two probabilities).
tf.random.set_seed(16)
np.random.seed(16)
# Define the model architecture
model = tf.keras.Sequential([
tf.keras.layers.Dense(32, input_shape=(X_tr.shape[1],), activation='relu'),
tf.keras.layers.Dense(70, activation='relu'),
tf.keras.layers.Dense(2, activation='softmax')
])
# Compile the model
model.compile(optimizer='adam',
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=['accuracy'])
# Early stopping callback based on validation accuracy
early_stopping = tf.keras.callbacks.EarlyStopping(
monitor='val_accuracy', # Monitor validation accuracy
patience=20, # Number of epochs with no improvement after which training will be stopped
restore_best_weights=True # Restore model weights from the epoch with the best value of the monitored quantity
)
# Train the model
history = model.fit(X_tr, ytr, epochs=500,
batch_size=32,
validation_data=(X_te, yte),
callbacks=[early_stopping])
# Evaluate the model on test data
test_loss, test_accuracy = model.evaluate(X_te, yte)
print(f"Final Test Accuracy: {test_accuracy:.2%}")
# Predictions
y_pred_prob = model.predict(X_te)
y_pred = np.argmax(y_pred_prob, axis=1)
# Convert one-hot encoded yte back to class labels
y_true = np.argmax(yte, axis=1)
# Calculate accuracy
accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy on test set: {accuracy:.2%}")
Epoch 1/500 26/26 [==============================] - 1s 13ms/step - loss: 0.5818 - accuracy: 0.7064 - val_loss: 0.4464 - val_accuracy: 0.8098
Epoch 2/500 26/26 [==============================] - 0s 4ms/step - loss: 0.4418 - accuracy: 0.8108 - val_loss: 0.3657 - val_accuracy: 0.8478
[epochs 3-23 omitted; val_accuracy plateaued around 0.88-0.89]
Epoch 24/500 26/26 [==============================] - 0s 7ms/step - loss: 0.2413 - accuracy: 0.9177 - val_loss: 0.3285 - val_accuracy: 0.8859
6/6 [==============================] - 0s 5ms/step - loss: 0.3303 - accuracy: 0.8913
Final Test Accuracy: 89.13%
6/6 [==============================] - 0s 4ms/step
Accuracy on test set: 89.13%
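The early-stopping callback halts training once `val_accuracy` has not improved for 20 consecutive epochs and then restores the best weights. Its bookkeeping can be sketched in plain Python (simplified; the real Keras callback also handles `min_delta`, `mode`, etc., and the accuracy sequence below is made up):

```python
def early_stopping_run(val_accuracies, patience):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch validation accuracies."""
    best_acc, best_epoch, wait = float("-inf"), 0, 0
    for epoch, acc in enumerate(val_accuracies):
        if acc > best_acc:
            best_acc, best_epoch, wait = acc, epoch, 0  # improvement: reset patience counter
        else:
            wait += 1
            if wait >= patience:      # no improvement for `patience` epochs in a row
                return epoch, best_epoch  # stop; weights of best_epoch would be restored
    return len(val_accuracies) - 1, best_epoch

stop, best = early_stopping_run([0.81, 0.85, 0.86, 0.89, 0.88, 0.88, 0.87], patience=3)
print(stop, best)  # stops at epoch 6, best weights are from epoch 3
```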
# saving the trained model into a keras file
#model.save('heart_disease_model_1_output.keras')
Reloading the saved, trained model (89.13% test accuracy):
#model = tf.keras.models.load_model('heart_disease_model_1_output.keras')
X_te.shape
(184, 11)
ytr = np.argmax(ytr, axis=1)
yte = np.argmax(yte, axis=1)
Confusion matrix:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
# Predictions: pick the class with the larger softmax probability
y_pred_prob = model.predict(X_te)
y_pred = np.argmax(y_pred_prob, axis=1)
# Confusion matrix against the true class labels
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap='Blues')
plt.title('Confusion Matrix')
plt.show()
print('accuracy = ', accuracy_score(y_true, y_pred))
print('precision = ', precision_score(y_true, y_pred))
print('recall = ', recall_score(y_true, y_pred))
print('f1 score = ', f1_score(y_true, y_pred))
6/6 [==============================] - 0s 4ms/step
accuracy =  0.8913043478260869
precision =  0.9081632653061225
recall =  0.89
f1 score =  0.8989898989898989
Accuracy (0.8913): This indicates that the model correctly classified approximately 89.13% of the instances. Accuracy is the ratio of correctly predicted instances to the total instances. However, it can be misleading if the dataset is imbalanced.
Precision (0.9082): This shows that when the model predicts the positive class, it is correct approximately 90.82% of the time. Precision is the ratio of true positive predictions to the sum of true positive and false positive predictions. It is a good metric when the cost of false positives is high.
Recall (0.89): This indicates that the model correctly identifies 89% of the actual positive instances. Recall is the ratio of true positive predictions to the sum of true positive and false negative predictions. It is crucial when the cost of false negatives is high.
F1 Score (0.8990): This is the harmonic mean of precision and recall, providing a balance between the two. An F1 score of 0.8990 suggests the model maintains a good balance between precision and recall.
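The four scores above can be cross-checked directly from the confusion-matrix counts. The counts below (TN=75, FP=9, FN=11, TP=89 over the 184 test samples) are the ones consistent with the reported scores:

```python
# Confusion-matrix counts consistent with the scores reported above (184 test samples)
tn, fp, fn, tp = 75, 9, 11, 89

accuracy  = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall    = tp / (tp + fn)
f1        = 2 * precision * recall / (precision + recall)  # harmonic mean

print(round(accuracy, 4), round(precision, 4), round(recall, 4), round(f1, 4))
# -> 0.8913 0.9082 0.89 0.899
```

With 11 false negatives out of 100 diseased patients, recall is the number to watch in this setting: a missed heart-disease case is costlier than a false alarm.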
3.6 SHAP Globally¶
shap.initjs()
X_tr_onehot = shap.kmeans(X_tr, 10)  # summarize the training data into 10 weighted centroids to serve as the SHAP background set
Computing the shap values for the test set:
explainer = shap.KernelExplainer(model.predict, X_tr_onehot)
shap_values = explainer.shap_values(X_te)
shap_values.shape
(184, 11, 1)
The shape of the shap_values is (184, 11, 1).
184: corresponds to the number of samples (instances) for which SHAP values were computed, i.e. the rows of the X_te test set. (The background data passed to the explainer was summarized beforehand with X_tr_onehot = shap.kmeans(X_tr, 10), so only a compact representation of the training samples is used during the SHAP calculation.)
11: represents the number of features.
1: This dimension is 1 since it is a binary classification problem (heart disease or no heart disease). For each feature in each sample, exactly one SHAP value is calculated, because the model has a single output value (towards 0 = no heart disease; towards 1 = heart disease).
Shape Interpretation: The shape (184, 11, 1) indicates that for each of the 184 samples, for each of the 11 features, 1 SHAP value has been computed (output towards 0 = no heart disease and output towards 1 = heart disease (threshold = 0.5)).
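The shape logic above can be sketched with synthetic numbers. The array below is random, standing in for the real SHAP output, and the base value of 0.55 is a hypothetical stand-in for `explainer.expected_value[0]`:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for the real (184, 11, 1) SHAP array
shap_values = rng.normal(scale=0.05, size=(184, 11, 1))

# Dropping the single-output axis gives a familiar (samples, features) matrix
sv = shap_values[:, :, 0]

# SHAP's additivity property: the base value plus the sum of one sample's
# SHAP values reconstructs that sample's model output
base_value = 0.55  # hypothetical explainer.expected_value[0]
prediction_0 = base_value + sv[0].sum()
print(sv.shape, round(float(prediction_0), 4))
```

With the real arrays, the same additivity check relates each patient's SHAP values to the model's predicted probability.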
features = X.columns.tolist()
X_te_df = pd.DataFrame(X_te)
3.6.1 Beeswarm plot¶
The beeswarm plot shows all SHAP values: one SHAP value is computed per feature, per observation.
Label encoded features in detail:
- ST-Slope:
Your heart beats because of electrical signals telling it when to contract (squeeze) and relax. This process repeats in a regular cycle, keeping blood pumping through your body. An ECG records these electrical signals and shows them as a series of waves.
The ST-Segment is a short but important part of this ECG wave that can give important clues about the heart's health, especially about the oxygen supply to the heart muscle.
It always depends on the context, but in general:
2 --> up sloping = usually healthy
1 --> flat = normal but can be concerning if significantly depressed or elevated.
0 --> downsloping = usually concerning
- Exercise Angina:
Exercise Angina is pain in the chest that comes with exercise, stress, or other things that make the heart work harder. It is an extremely common symptom of coronary artery disease, which is caused by cholesterol-clogged coronary arteries.
- Fasting blood sugar:
Fasting blood sugar is a common blood test to diagnose diabetes and prediabetes, which is a common risk factor for heart disease. In this case, if the FastingBS > 120 mg/dl = 1, otherwise = 0. (Label encoding)
- Oldpeak:
The Oldpeak is a value in a cardiac context representing the depression observed between the first and second parts of the ST segment in an ECG.
An ST-segment depression of less than 0.5 mm is accepted in all leads; a depression of 0.5 mm or more is considered pathological.
- Chest Pain Type:
asy = 0 asymptomatic
ata = 1 atypical angina
nap = 2 non-anginal pain
ta = 3 typical angina
All other features should be clear given the feature descriptions for the dataset at the beginning.
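When reading feature values off the beeswarm plot, the label encodings described above can be translated back into readable categories. The helper below is hypothetical, built only from the mappings listed in this section:

```python
# Decoding maps taken from the feature descriptions above
ST_SLOPE = {2: "upsloping", 1: "flat", 0: "downsloping"}
CHEST_PAIN = {0: "asymptomatic", 1: "atypical angina",
              2: "non-anginal pain", 3: "typical angina"}

def decode(st_slope: int, chest_pain: int, fasting_bs: int) -> dict:
    """Translate label-encoded values back into human-readable categories."""
    return {
        "ST_Slope": ST_SLOPE[st_slope],
        "ChestPainType": CHEST_PAIN[chest_pain],
        "FastingBS": "> 120 mg/dl" if fasting_bs == 1 else "<= 120 mg/dl",
    }

print(decode(1, 0, 0))
```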
shap.summary_plot(shap_values[:,:,0], X_te, feature_names=features)
The beeswarm plot provides a detailed overview of how each feature influences the model's output across the entire dataset. Each dot in the plot represents the SHAP value for the respective feature of a single observation.
The x-axis of the plot represents the SHAP values. A positive SHAP value pushes the model's prediction towards 1 (heart disease), while a negative SHAP value pushes the prediction towards 0 (no heart disease).
The more widespread the SHAP values (dots) for a feature, the greater its overall influence on the model's predictions across the dataset; the more narrowly the dots cluster around a SHAP value of 0, the smaller that influence.
We can see that an upsloping ST-segment (dots colored in red, with a feature value of 2 because of the label encoding) pushes the model towards predicting no heart disease (negative SHAP values). This makes sense, since an upsloping ST-segment is generally considered healthy. On the other hand, flat or downsloping ST-segments have positive SHAP values, pushing the model towards predicting heart disease. The dots (SHAP values) are most widespread for this feature, indicating that ST-Slope has the greatest overall influence on the model's predictions.
We also see that having exercise angina or not has a significant general impact on the model's prediction. Having exercise angina (red dotted positive shap values) generally pushes the model's prediction towards heart disease. Not having exercise angina (blue dotted negative shap values) pushes the model towards no heart disease prediction.
We can also see that being male pushes the model to predict a heart disease for this dataset. Since the dataset has more males with heart disease, the model might learn to associate the male gender with a higher likelihood of heart disease. The fewer observations for females and their lower incidence of heart disease means the model has less information to learn about females, potentially leading to a stronger bias towards associating males with heart disease. The model might be biased due to the imbalance in the dataset. If males are overrepresented and have a higher rate of heart disease, the model might be more sensitive to this feature and thus give it more weight in the prediction.
We can also see that the type of chest pain, fasting blood sugar, and oldpeak (related to the ST-Slope) have a moderate general impact on the model's prediction, as does the age of the patient.
Again, the narrower the dots (shap values) for each feature, the less general impact the feature has on the model prediction.
Also, cholesterol probably does not have such a high impact in this case because, for 172 observations in the original dataset, the cholesterol levels were 0 and thus have been imputed with the mean value of all available cholesterol levels in the dataset.
Next, let's have a look at the mean-SHAP plot. For each feature, this plot gives the mean absolute SHAP value across all instances (patients) in our dataset. Features that tend to make significant contributions to the model's predictions will have high mean SHAP values. This plot tells us which features are most important in general:
shap_values_explanation = shap.Explanation(shap_values[:,:,0],
feature_names=features,
data=X_te_df.values)
shap.plots.bar(shap_values_explanation, max_display=11)
shap.plots.beeswarm(shap_values_explanation.abs, color="shap_red")
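The ranking behind the bar plot can be reproduced by hand. The sketch below uses a random matrix and placeholder feature names in place of the real `shap_values[:, :, 0]`:

```python
import numpy as np

rng = np.random.default_rng(1)
# Synthetic stand-in for shap_values[:, :, 0]: 184 samples x 11 features
shap_matrix = rng.normal(size=(184, 11))
feature_names = [f"feature_{i}" for i in range(11)]  # placeholder names

# The bar plot ranks features by the mean absolute SHAP value per column
mean_abs = np.abs(shap_matrix).mean(axis=0)
ranking = sorted(zip(feature_names, mean_abs),
                 key=lambda pair: pair[1], reverse=True)
print(ranking[0][0])  # the most influential feature in this synthetic data
```

On the real data, the same `mean(|SHAP|)` per feature is exactly what `shap.plots.bar` displays.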
The dependence scatter plot shows that the higher the age, the higher the chance the model predicts heart disease:
shap.plots.scatter(shap_values_explanation[:,0])
Also, the higher the cholesterol, the higher the chance for the model to predict heart disease:
shap.plots.scatter(shap_values_explanation[:,4])
Also, the higher the oldpeak (measured in millimeters), the higher the chance the model predicts heart disease:
shap.plots.scatter(shap_values_explanation[:,9])
shap.plots.heatmap(shap_values_explanation, max_display = 11)
<Axes: xlabel='Instances'>
3.7 Shap Locally¶
X_te.shape
(184, 11)
# Create a KernelExplainer
explainer = shap.KernelExplainer(model.predict, X_tr_onehot)
1/1 [==============================] - 0s 33ms/step
Computing SHAP values for a single patient: the observation at index 1 of the X_te test set. This computes SHAP values specifically for that one observation.
shap_values_1 = explainer.shap_values(X_te[1,:])
1/1 [==============================] - 0s 23ms/step 640/640 [==============================] - 1s 1ms/step
Now that we have computed the SHAP values for only this one row (patient) of our X_te test set, we can see that they have a shape of (11, 1), meaning one SHAP value has been computed for each of the 11 features of this single observation, matching the model's single output.
shap_values_1.shape
(11, 1)
Let's see what the model predicts:
model.predict(X_te)[1,:]
6/6 [==============================] - 0s 2ms/step
array([0.9332918], dtype=float32)
We can see that the model predicted that this patient has a heart disease with a probability of around 93.3%. We can check if the model is correct by viewing the target variable for this observation/patient:
yte[1]
1
We can see that the target value (true value) for this patient is 1, indicating that this patient indeed has heart disease. Let's have a look at the health data for this patient by converting the standard-scaled features back to their original values:
# convert the standard-scaled features back to their original values
standard_scaled_values = X_te[1, :]
# Get the mean and standard deviation used during the fit
mean = scaler.mean_
scale = scaler.scale_
# Reverse the standard scaling
original_values = standard_scaled_values * scale + mean
print(original_values)
[ 50. 1. 0. 133. 218. 0. 1. 128. 1. 1.1 1. ]
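The manual reversal above is equivalent to scikit-learn's built-in `inverse_transform`. A minimal sketch with toy data (the two-column array below is illustrative, not the real feature matrix):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy data standing in for the real training features
X_demo = np.array([[50., 130.], [60., 140.], [40., 120.]])
scaler = StandardScaler().fit(X_demo)
scaled_row = scaler.transform(X_demo)[0]

# Manual reversal, as in the cell above ...
manual = scaled_row * scaler.scale_ + scaler.mean_
# ... matches scikit-learn's built-in inverse_transform
builtin = scaler.inverse_transform(scaled_row.reshape(1, -1))[0]
print(np.allclose(manual, builtin), np.allclose(manual, X_demo[0]))
```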
Observation:
We can see that this patient is 50 years old, male, with an asymptomatic chest pain type. The patient has a flat ST-Slope (1) and has exercise angina, meaning chest pain that comes with exercise, stress, or other things that make the heart work harder, which is an extremely common symptom of coronary artery disease, caused by cholesterol-clogged coronary arteries.
The patient also has a concerning oldpeak.
X.iloc[529,:]
Age 50.0 Sex 1.0 ChestPainType 0.0 RestingBP 133.0 Cholesterol 218.0 FastingBS 0.0 RestingECG 1.0 MaxHR 128.0 ExerciseAngina 1.0 Oldpeak 1.1 ST_Slope 1.0 Name: 530, dtype: float64
The following can be interpreted: the model predicts heart disease with a 93.3% probability. We can also clearly see which variables push up the result or influence the model's decision towards this prediction.
This is because we are looking from the perspective of heart disease, where a prediction towards 1 indicates heart disease and a prediction towards 0 indicates no heart disease.
We can see here that Exercise Angina has the biggest impact on this prediction, followed by the ST-Slope and the fact that the patient is male. The chest pain type, the oldpeak, and the patient's maximum heart rate also contribute to this decision. However, we can also see that some features push the model away from predicting heart disease with 100% certainty, the highest such impact being this patient's fasting blood sugar, as it is below the threshold for a concerning level.
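The ordering the force plot visualizes can also be sketched as plain text. The SHAP values below are made-up illustrative numbers, not this patient's actual values:

```python
import numpy as np

# Hypothetical local SHAP values for one patient (made-up numbers)
feature_names = ["Age", "Sex", "ChestPainType", "RestingBP", "Cholesterol",
                 "FastingBS", "RestingECG", "MaxHR", "ExerciseAngina",
                 "Oldpeak", "ST_Slope"]
local_shap = np.array([0.02, 0.05, 0.04, -0.01, 0.01,
                       -0.06, 0.00, 0.03, 0.12, 0.03, 0.08])

# Rank features by absolute contribution, noting the push direction
order = np.argsort(-np.abs(local_shap))
for idx in order[:4]:
    direction = "towards" if local_shap[idx] > 0 else "against"
    print(f"{feature_names[idx]}: {local_shap[idx]:+.2f} ({direction} heart disease)")
```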
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values_1[:, 0], X_te[1,:], feature_names = features)
Waterfall diagram for a more comprehensible view:
shap.initjs()
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[0], shap_values_1[:, 0], X_te[1,:], feature_names = features)
Now let's look at the explanations for a patient where the target value is 0 (no heart disease) and the model prediction was also 0:
shap_values_0 = explainer.shap_values(X_te[150,:])
1/1 [==============================] - 0s 34ms/step 640/640 [==============================] - 2s 2ms/step
shap_values_0.shape
(11, 1)
Let's see what the model predicts:
model.predict(X_te)[150,:]
6/6 [==============================] - 0s 5ms/step
array([0.05194689], dtype=float32)
The model is ~5% sure that this patient has heart disease or, in other words, ~95% sure that they do not. We can check the target variable and see that the model is correct:
yte[150]
0
# convert the standard-scaled features back to their original values
standard_scaled_values = X_te[150, :]
# Get the mean and standard deviation used during the fit
mean = scaler.mean_
scale = scaler.scale_
# Reverse the standard scaling
original_values = standard_scaled_values * scale + mean
print(original_values)
[ 41. 1. 1. 120. 295. 0. 1. 170. 0. 0. 2.]
Let's have a look at the health data for this single patient:
Observation:
We can see that this patient is 41 years old, male, with a chest pain type of atypical angina. The patient has a generally healthy upsloping ST-Slope (2) and no exercise angina, meaning NO chest pain that comes with exercise, stress, or other things that make the heart work harder (an extremely common symptom of coronary artery disease).
The patient also has no concerning oldpeak and no concerning fasting blood sugar.
X.iloc[228,:]
Age 41.0 Sex 1.0 ChestPainType 1.0 RestingBP 120.0 Cholesterol 295.0 FastingBS 0.0 RestingECG 1.0 MaxHR 170.0 ExerciseAngina 0.0 Oldpeak 0.0 ST_Slope 2.0 Name: 228, dtype: float64
The following can be interpreted: for this patient, the model predicts heart disease with a 5% chance. We can also clearly see which variables push down the result or influence the model's decision towards these 5%.
Again, this is because we are looking from the perspective of heart disease. In general, this means that the patient has no heart disease (threshold 0.5).
In this case, the healthy ST-Slope has the most impact on the decision, followed by the fact that the patient has NO exercise angina, a healthy or unconcerning oldpeak, and is also relatively young (41 years old). The unconcerning fasting blood sugar, the type of chest pain, and the maximum heart rate also push the model's decision against the prediction of heart disease.
On the other hand, we also see some features of this patient pushing the model towards a prediction of heart disease, such as the fact that the patient is male, the rather high cholesterol level, and his resting blood pressure.
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values_0[:, 0], X_te[150,:], feature_names = features)
Waterfall diagram for a more comprehensible view:
shap.initjs()
shap.plots._waterfall.waterfall_legacy(explainer.expected_value[0], shap_values_0[:, 0], X_te[150,:], feature_names = features)
4 Qualitative Analysis and Results¶
RQ1: What are the key differences in the interpretability of machine learning models between image classification and tabular data analysis when utilizing SHAP explanations?¶
Key Differences in the Interpretability of Machine Learning Models Between Image Classification and Tabular Data Analysis Using SHAP Explanations:
Local vs. Global Interpretability:
- Image Classification: SHAP explanation can only be used locally as each image is different, providing insights on a per-image basis. Each explanation is specific to the image being analyzed, highlighting which pixels contributed most to the model's prediction.
- Tabular Data Analysis: SHAP can be used both locally and globally. Locally, it explains individual predictions by identifying important features for a specific instance. Globally, it can aggregate explanations to understand overall feature importance across the entire dataset.
Data Characteristics:
- Image Data: The underlying data consists of pixels, which are spatially structured and often contain complex patterns that the model must learn to recognize.
- Tabular Data: The data consists of features that can vary widely in type and scale. The relationships between features are not as visually intuitive as pixel arrangements and often require context to interpret.
Interpretation:
- Image Data: SHAP explanations for image data are highly intuitive, as highlighted regions can be directly visualized, making it easy to see which parts of the image influenced the model's decision.
- Tabular Data: Interpreting SHAP explanations for tabular data often requires more domain knowledge. Understanding the impact of different features on the model's predictions involves comprehending the underlying data and the context of the features.
RQ2: How does SHAP reveal the importance of different regions or structures within MRI brain scans in the decision-making process of machine learning models for brain tumor detection?¶
The SHAP explanations have proven to be a powerful tool for interpreting image data from MRI brain scans, despite challenges such as variations in imaging angles. Particularly, when the tumor region is distinctly visible, SHAP explanations consistently deliver interpretable results and provide valuable insights into the model's decision-making process.
- The model effectively differentiates between four classes of brain tumors.
- The SHAP explanations can accurately reveal the tumor regions within the MRI brain scans by visualizing the model's decision-making.
- SHAP analyses also highlighted regions not readily apparent to non-specialists, which can also assist medical professionals in tumor detection.
- Removing the "notumor" class did not improve model performance.
- SHAP explanations enhance interpretability but do not always align with model accuracy, sometimes highlighting incorrect regions.
- Binary classification did not yield optimal SHAP results for distinguishing unhealthy regions despite accurate predictions, indicating complexity in interpreting SHAP for binary health assessments.
RQ3: How does SHAP analysis reveal the importance of various clinical and demographic features in the prediction of heart failure using machine learning models?¶
- The simple ANN, which we built ourselves with 1 hidden layer, can determine whether a patient has heart disease with 89% accuracy based on the 11 characteristics (risk factors) present in the dataset for each patient.
- The global SHAP explanations show which features or risk factors have the highest influence (positive and negative) on the model's predictions, generalized across all observations. The ranking of the features by their influence is also very reasonable: the most influential is the ST-Slope, which describes a part of the heartbeat curve measured in an ECG, with Exercise Angina in 2nd place, which describes chest pain occurring during exercise or stress, followed by sex, chest pain type, fasting blood sugar, oldpeak, and age.
- The reason why cholesterol does not have such a high general impact might be because for 172 observations the cholesterol levels in the dataset were 0, and have been imputed with the mean value of all available cholesterol levels in the dataset. However, when SHAP is applied locally to patients for whom this value was not imputed, the SHAP explanations indeed show that cholesterol can have a high impact on the model decision if the cholesterol value deviates strongly from the imputed mean value.
- Applied locally to single patients (observations), the SHAP explanations reveal to what extent each feature, given its value, contributes to the model's prediction (probability value) for that patient. For a given patient, they show which features push towards a higher predicted probability of heart disease and which push towards a lower one, based on that patient's feature values.
- Applied locally, for very clear predictions, the SHAP explanations are reasonable and clear to interpret. However, for prediction probabilities close to 0.5, the explanations become less clear and comprehensible, reflecting the need for human oversight by medical professionals.
- Nevertheless, given the model's recall score of 0.89, healthcare professionals should always review or cross-check each individual prediction/decision of the model for each patient, as there is still the possibility of false negative predictions.
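The implication of the recall score can be made concrete with a small arithmetic sketch (the counts below are hypothetical, chosen only to be consistent with a recall of 0.89):

```python
# Recall = TP / (TP + FN): the share of truly diseased patients the model finds.
# Hypothetical counts consistent with a recall of 0.89
tp, fn = 89, 11
recall = tp / (tp + fn)
missed_rate = fn / (tp + fn)  # share of diseased patients missed (false negatives)
print(recall, missed_rate)
```

In other words, roughly 11 of every 100 patients who truly have heart disease would be missed, which is why each prediction still warrants review by a medical professional.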
5. Summary¶
The model to classify tumors performed very well, demonstrating high accuracy in its predictions. SHAP explanations are instrumental in understanding the model's decisions by highlighting the important regions of the brain. SHAP not only reveals that the model can identify the correct regions of the brain but also uncovers instances where the model learns patterns that allow it to make correct predictions without actually identifying the correct regions. SHAP explanations have also proven to be powerful for interpreting model decisions on tabular data, such as in this case of patient data with 11 features for predicting heart disease. In the context of tabular data, SHAP not only provides local explanations for the model's decision on a single patient (observation) but can also be used for global (general) explanations of model decisions across the entire dataset.
All in all, SHAP brings much-needed explainability into the decision-making process. This transparency is of utmost importance, as unexplainable decisions are not useful in real-life medical scenarios. Ensuring that medical professionals can trust and understand AI-driven diagnoses is crucial for safe and effective patient care.
6. References¶
Brain Tumor MRI dataset. (2021, September 24). Kaggle. https://www.kaggle.com/datasets/masoudnickparvar/brain-tumor-mri-dataset
Chander, A., Wang, J., Srinivasan, R., Uchino, K., & Chelian, S. (2018). Working with Beliefs: AI Transparency in the Enterprise [Journal-article]. http://ceur-ws.org/Vol-2068/exss14.pdf
Heart Failure Prediction Dataset. (2021, September 10). Kaggle. https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction/data
Holzinger, A., Biemann, C., Pattichis, C. S., & Kell, D. B. (2017). What do we need to build explainable AI systems for the medical domain? arXiv (Cornell University). https://doi.org/10.48550/arxiv.1712.09923
Knapič, S., Malhi, A., Saluja, R., & Främling, K. (2021). Explainable artificial intelligence for human decision support system in the medical domain. Machine Learning and Knowledge Extraction, 3(3), 740–770. https://doi.org/10.3390/make3030037
Meske, C., & Bunde, E. (2020). Transparency and Trust in Human-AI-Interaction: The role of Model-Agnostic Explanations in Computer Vision-Based Decision Support. In Lecture notes in computer science (pp. 54–69). https://doi.org/10.1007/978-3-030-50334-5_4
Murel, J., & Kavlakoglu, E. (2024, February 12). What is transfer learning? | IBM. IBM. https://www.ibm.com/topics/transfer-learning
Salehi, A. W., Khan, S., Gupta, G., Alabduallah, B. I., Almjally, A., Alsolai, H., Siddiqui, T., & Mellit, A. (2023). A study of CNN and transfer learning in Medical imaging: Advantages, challenges, future scope. Sustainability, 15(7), 5930. https://doi.org/10.3390/su15075930
Voigt, P., & Von Dem Bussche, A. (2017). The EU General Data Protection Regulation (GDPR). In Springer eBooks. https://doi.org/10.1007/978-3-319-57959-7
Welcome to the SHAP documentation — SHAP latest documentation. (n.d.). https://shap.readthedocs.io/en/latest/index.html